{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "859b5e03-2c41-4350-99f2-6cc6bf39aeb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "import d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75515768-b285-4ede-aa34-de5edb07208c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = d2l.read_ptb()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ba3a916e-02f1-46d3-bb8a-f090bb963b25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# sentences数: 42069'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f'# sentences数: {len(sentences)}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7cc8c853-e100-4f62-b7b9-a8fb27c6cc77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'vocab size: 6719'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = d2l.Vocab(sentences, min_freq=10) \n",
    "f'vocab size: {len(vocab)}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4cfed082-8ef6-47ff-8467-ff90ac1d5baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "subsampled, counter = d2l.subsample(sentences, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "48ca92ad-9918-463c-89c2-1502c775a628",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAGwCAYAAAC0HlECAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIKUlEQVR4nO3deXhU5f2/8fcQskFIwpZNQ9gKBFkFiRFFkUhQakWtivJTEIRKQcUoAi6ARcViFRRRqlbw20IVFbCCAmlkKfsiWxAQMIoVQqiSBAKEkDy/P2hOGRJgMkxyZjL367rmKpnzzJnPySC5OzM54zDGGAEAAOCCatg9AAAAgC8gmgAAAFxANAEAALiAaAIAAHAB0QQAAOACogkAAMAFRBMAAIALato9QHVRUlKiAwcOqE6dOnI4HHaPAwAAXGCM0dGjRxUXF6caNS78XBLR5CEHDhxQfHy83WMAAAA3/Pjjj7r88ssvuIZo8pA6depIOvNNDw8Pt3kaAADgivz8fMXHx1s/xy/I2Oill14ynTt3NmFhYaZhw4bmtttuM7t27XJac/311xtJTpff/e53Tmt++OEHc8stt5jQ0FDTsGFD8+STT5qioiKnNUuXLjUdO3Y0QUFBplmzZmbGjBll5nnzzTdNQkKCCQ4ONl26dDHr1q1z+Vjy8vKMJJOXl+f6NwAAANiqIj+/bX0j+PLlyzVs2DCtXbtW6enpKioqUs+ePVVQUOC0bvDgwTp48KB1mTRpkrWtuLhYvXv31qlTp7R69Wp98MEHmjlzpsaOHWutycrKUu/evdW9e3dt2bJFI0aM0EMPPaTFixdbaz766COlpaVp3Lhx+vrrr9W+fXulpqYqJyen8r8RAADA6zmM8Z4P7D18+LCioqK0fPlydevWTZJ0ww03qEOHDpoyZUq5t/nyyy/161//WgcOHFB0dLQkafr06Ro1apQOHz6soKAgjRo1SgsXLlRmZqZ1u759+yo3N1eLFi2SJCUlJemqq67Sm2++KenMG7vj4+P1yCOPaPTo0RedPT8/XxEREcrLy+PlOQAAfERFfn571SkH8vLyJEn16tVzun7WrFlq0KCB2rRpozFjxuj48ePWtjVr1qht27ZWMElSamqq8vPztWPHDmtNSkqK0z5TU1O1Zs0aSdKpU6e0adMmpzU1atRQSkqKteZchYWFys/Pd7oAAIDqy2veCF5SUqIRI0aoa9euatOmjXX9fffdp4SEBMXFxWnbtm0aNWqUdu/erblz50qSsrOznYJJkvV1dnb2Bdfk5+frxIkTOnLkiIqLi8tds2vXrnLnnThxop5//vlLO2gAQKUpLi5WUVGR3WPAZoGBgQoICPDIvrwmmoYNG6bMzEytXLnS6fohQ4ZYf27btq1iY2PVo0cP7du3T82aNavqMS1jxoxRWlqa9XXpu+8BAPYyxig7O1u5ubl2jwIvERkZqZiYmEs+j6JXRNPw4cO1YMECrVix4qLnSEhKSpIk7d27V82aNVNMTIzWr1/vtObQoUOSpJiYGOt/S687e014eLhCQ0MVEBCggICActeU7uNcwcHBCg4Odv0gAQBVojSYoqKiVKtWLU447MeMMTp+/Lj1S12xsbGXtD9bo8kYo0ceeUTz5s3TsmXL1KRJk4veZsuWLZL+d+DJycl68cUXlZOTo6ioKElSenq6wsPD1bp1a2vNF1984bSf9PR0JScnS5KCgoLUqVMnZWRkqE+fPpLOvFyYkZGh4cOHe+JQAQBVoLi42Aqm+vXr2z0OvEBoaKgkWZ1wKS/V2RpNw4YN0+zZs/XZZ5+pTp061nuQIiIiFBoaqn379mn27Nm65ZZbVL9+fW3btk2PP/64unXrpnbt2kmSevbsqdatW+v+++/XpEmTlJ2drWeffVbDhg2zngl6+OGH9eabb+qpp57SwIED9dVXX2nOnDlauHChNUtaWpr69++vzp07q0uXLpoyZYoKCgr04IMP
Vv03BgDgltL3MNWqVcvmSeBNSv8+FBUVXdr7myr3lFEXpnNOWll6KT3x5P79+023bt1MvXr1THBwsGnevLkZOXJkmRNQff/99+bmm282oaGhpkGDBuaJJ54o9+SWHTp0MEFBQaZp06blntxy6tSpplGjRiYoKMh06dLFrF271uVj4eSWAGC/EydOmG+++cacOHHC7lHgRS7096IiP7+96jxNvozzNAGA/U6ePKmsrCw1adJEISEhdo8DL3Ghvxc+e54mAAAAb+UVvz0HAEBlazx64cUXecj3L/eukvsZP3685s+fb/2SlCsu9kkbOD+iCQAAH/Xkk0/qkUceqdBt5s6dq8DAwEqaqHojmgAA8DHGGBUXFyssLExhYWEVuu25H1UG1/GeJgAAvEBhYaEeffRRRUVFKSQkRNdee602bNggSVq2bJkcDoe+/PJLderUScHBwVq5cqXGjx+vDh06WPs4ffq0Hn30UUVGRqp+/foaNWqU+vfvb52DUDrz8tyIESOsrxs3bqyXXnpJAwcOVJ06ddSoUSO98847VXTUvoVnmvzN+IgK3+RooVGvWceVmVOs9Ptrq8tl/z3Hxfg8Dw8HAP7rqaee0qeffqoPPvhACQkJmjRpklJTU7V3715rzejRo/WnP/1JTZs2Vd26dbVs2TKnffzxj3/UrFmzNGPGDCUmJur111/X/Pnz1b179wve96uvvqoJEybo6aef1ieffKKhQ4fq+uuvV8uWLSvjUH0WzzThgs4bTAAAjykoKNDbb7+tV155RTfffLNat26td999V6GhofrLX/5irfvDH/6gm266Sc2aNSv3ZbapU6dqzJgxuv3229WqVSu9+eabioyMvOj933LLLfr973+v5s2ba9SoUWrQoIGWLl3qyUOsFogmnBfBBABVY9++fSoqKlLXrl2t6wIDA9WlSxft3LnTuq5z587n3UdeXp4OHTqkLl26WNcFBASoU6dOF73/0k/ZkCSHw6GYmBjr89rwP0QTykUwAYD3qV27dqXs99zfpnM4HCopKamU+/JlRBPKIJgAoGo1a9ZMQUFBWrVqlXVdUVGRNmzYYH34/MVEREQoOjraevO4dOYDjL/++muPz+uveCM4nBBMAFD1ateuraFDh2rkyJGqV6+eGjVqpEmTJun48eMaNGiQtm7d6tJ+HnnkEU2cOFHNmzdXq1atNHXqVB05ckQOh6OSj8A/EE2wEEwAqrOqOku3u15++WWVlJTo/vvv19GjR9W5c2ctXrxYdevWdXkfo0aNUnZ2th544AEFBARoyJAhSk1NVUAA/557Ah/Y6yE+84G95znlgFvBxCkHAHgZPrDXWUlJiRITE3X33XdrwoQJdo9jG099YC/PNIFnmACgmvjhhx+0ZMkSXX/99SosLNSbb76prKws3XfffXaPVi3wRnA/RzABQPVRo0YNzZw5U1dddZW6du2q7du365///KcSExPtHq1a4JkmP0YwAUD1Eh8f7/QbePAsnmnyUwQTAAAVQzT5IYIJAICKI5r8DMEEAIB7iCY/QzABAOAeosnPEEwAALiHaPIzBBMAVD+NGzfWlClT7B7DLZ6Yffz48erQoYNH5rkQTjngZwgmAH7rPJ+IcK4XVhTquaWFmtA9WM92C3bzvvjEhOqIZ5oAAPgvjwQTqi2iCQAA2R9Mn3zyidq2bavQ0FDVr19fKSkpKigo0A033KARI0Y4re3Tp48GDBjgdN3Ro0d17733qnbt2rrssss0bdo0a5sxRuPHj1ejRo0UHBysuLg4Pfroo9b2v/71r+rcubPq1KmjmJgY3XfffcrJybG2L1u2TA6HQ4sXL1bHjh0VGhqqG2+8UTk5Ofryyy+VmJio8PBw3XfffTp+/Lh1uxtuuEHDhw/X8OHDFRERoQYNGui5557ThT72Njc3Vw899JAaNmyo8PBw3Xjjjdq6davTmpdfflnR0dGqU6eOBg0apJMnT1bkW+02ogkA4PfsDqaDBw/q3nvv
1cCBA7Vz504tW7ZMd9xxxwXj4lyvvPKK2rdvr82bN2v06NF67LHHlJ6eLkn69NNPNXnyZP35z3/Wnj17NH/+fLVt29a6bVFRkSZMmKCtW7dq/vz5+v7778tEmXTmvUNvvvmmVq9erR9//FF33323pkyZotmzZ2vhwoVasmSJpk6d6nSbDz74QDVr1tT69ev1+uuv67XXXtN777133uO46667rBjbtGmTrrzySvXo0UO//PKLJGnOnDkaP368XnrpJW3cuFGxsbF66623XP4+XQre0wQA8Gt2B5N0JppOnz6tO+64QwkJCZLkFDWu6Nq1q0aPHi1JatGihVatWqXJkyfrpptu0v79+xUTE6OUlBQFBgaqUaNG6tKli3XbgQMHWn9u2rSp3njjDV111VU6duyYwsLCrG0vvPCCunbtKkkaNGiQxowZo3379qlp06aSpN/+9rdaunSpRo0aZd0mPj5ekydPlsPhUMuWLbV9+3ZNnjxZgwcPLnMMK1eu1Pr165WTk6Pg4DOPxZ/+9CfNnz9fn3zyiYYMGaIpU6Zo0KBBGjRokDXTP//5zyp5tolnmgAAfssbgkmS2rdvrx49eqht27a666679O677+rIkSMV2kdycnKZr3fu3CnpzLM3J06cUNOmTTV48GDNmzdPp0+fttZu2rRJt956qxo1aqQ6dero+uuvlyTt37/faZ/t2rWz/hwdHa1atWpZwVR63dkv60nS1VdfLYfD4TTXnj17VFxcXOYYtm7dqmPHjql+/foKCwuzLllZWdq3b58kaefOnUpKSrrgsVcWogkA4Je8JZgkKSAgQOnp6fryyy/VunVrTZ06VS1btlRWVpZq1KhR5mW6oqKiCu0/Pj5eu3fv1ltvvaXQ0FD9/ve/V7du3VRUVKSCggKlpqYqPDxcs2bN0oYNGzRv3jxJ0qlTp5z2ExgYaP3Z4XA4fV16XUlJSYVmO9uxY8cUGxurLVu2OF12796tkSNHur1fT+HlOQCA3/GmYCrlcDjUtWtXde3aVWPHjlVCQoLmzZunhg0b6uDBg9a64uJiZWZmqnv37k63X7t2bZmvExMTra9DQ0N166236tZbb9WwYcPUqlUrbd++XcYY/fzzz3r55ZcVHx8vSdq4caPHjmvdunVl5vrVr36lgICyp8C58sorlZ2drZo1a6px48bl7i8xMVHr1q3TAw884LTPqkA0AQD8ijcG07p165SRkaGePXsqKipK69at0+HDh5WYmKjatWsrLS1NCxcuVLNmzfTaa68pNze3zD5WrVqlSZMmqU+fPkpPT9fHH3+shQsXSpJmzpyp4uJiJSUlqVatWvrb3/6m0NBQJSQkqKSkREFBQZo6daoefvhhZWZmasKECR47tv379ystLU2/+93v9PXXX2vq1Kl69dVXy12bkpKi5ORk9enTR5MmTVKLFi104MABLVy4ULfffrs6d+6sxx57TAMGDFDnzp3VtWtXzZo1Szt27HB6mbCyEE0AAL/hjcEkSeHh4VqxYoWmTJmi/Px8JSQk6NVXX9XNN9+soqIibd26VQ888IBq1qypxx9/vMyzTJL0xBNPaOPGjXr++ecVHh6u1157TampqZKkyMhIvfzyy0pLS1NxcbHatm2rzz//XPXr15d0JqqefvppvfHGG7ryyiv1pz/9Sb/5zW88cmwPPPCATpw4oS5duiggIECPPfaYhgwZUu5ah8OhL774Qs8884wefPBBHT58WDExMerWrZuio6MlSffcc4/27dunp556SidPntSdd96poUOHavHixR6Z90IcpiK/z4jzys/PV0REhPLy8hQeHm73OOfn4hlxXdsXZ7wF4F1OnjyprKwsNWnSRCEhIXaP4/duuOEGdejQwfaPeLnQ34uK/PzmjeAAAAAuIJrglhdWFNo9AgAAVYr3NKHCSt8T8KzdgwAAvNqyZcvsHsGjeKYJFXL2mygBAPAnRBNc5q2/dQIA5+J3nHA2T/19IJrgEoIJgC8o
PUP18ePHbZ4E3qT078O5ZzCvKN7ThIsimAD4ioCAAEVGRlqff1arVi2nzz2DfzHG6Pjx48rJyVFkZGS5ZyGvCKIJF0QwAfA1MTExklTmg2PhvyIjI62/F5eCaMJ5EUwAfJHD4VBsbKyioqIq/MG2qH4CAwMv+RmmUkQTykUwAfB1AQEBHvthCUi8ERzlIJgAACiLaIITggkAgPIRTbAQTAAAnB/RBEkEEwAAF0M0gWACAMAFRJOfI5gAAHAN0eTHCCYAAFxHNPkpggkAgIohmvwQwQQAQMURTX6GYAIAwD1Ek58hmAAAcA/R5GcIJgAA3EM0+RmCCQAA9xBNAAAALiCaAAAAXEA0AQAAuIBoAgAAcAHRBAAA4AKiCQAAwAW2RtPEiRN11VVXqU6dOoqKilKfPn20e/dupzUnT57UsGHDVL9+fYWFhenOO+/UoUOHnNbs379fvXv3Vq1atRQVFaWRI0fq9OnTTmuWLVumK6+8UsHBwWrevLlmzpxZZp5p06apcePGCgkJUVJSktavX+/xYwYAAL7J1mhavny5hg0bprVr1yo9PV1FRUXq2bOnCgoKrDWPP/64Pv/8c3388cdavny5Dhw4oDvuuMPaXlxcrN69e+vUqVNavXq1PvjgA82cOVNjx4611mRlZal3797q3r27tmzZohEjRuihhx7S4sWLrTUfffSR0tLSNG7cOH399ddq3769UlNTlZOTUzXfDAAA4NUcxhhj9xClDh8+rKioKC1fvlzdunVTXl6eGjZsqNmzZ+u3v/2tJGnXrl1KTEzUmjVrdPXVV+vLL7/Ur3/9ax04cEDR0dGSpOnTp2vUqFE6fPiwgoKCNGrUKC1cuFCZmZnWffXt21e5ublatGiRJCkpKUlXXXWV3nzzTUlSSUmJ4uPj9cgjj2j06NFlZi0sLFRhYaH1dX5+vuLj45WXl6fw8PBK+x5dsvERHtxXnuf2BQCADfLz8xUREeHSz2+vek9TXt6ZH8L16tWTJG3atElFRUVKSUmx1rRq1UqNGjXSmjVrJElr1qxR27ZtrWCSpNTUVOXn52vHjh3WmrP3UbqmdB+nTp3Spk2bnNbUqFFDKSkp1ppzTZw4UREREdYlPj7+Ug8fAAB4Ma+JppKSEo0YMUJdu3ZVmzZtJEnZ2dkKCgpSZGSk09ro6GhlZ2dba84OptLtpdsutCY/P18nTpzQf/7zHxUXF5e7pnQf5xozZozy8vKsy48//ujegQMAAJ9Q0+4BSg0bNkyZmZlauXKl3aO4JDg4WMHB/vs5but/KlYXu4cAAKAKecUzTcOHD9eCBQu0dOlSXX755db1MTExOnXqlHJzc53WHzp0SDExMdaac3+brvTri60JDw9XaGioGjRooICAgHLXlO4D/7P+p2Ld9NeCiy8EAKAasTWajDEaPny45s2bp6+++kpNmjRx2t6pUycFBgYqIyPDum737t3av3+/kpOTJUnJycnavn2702+5paenKzw8XK1bt7bWnL2P0jWl+wgKClKnTp2c1pSUlCgjI8NagzNKg6lNVIDdowAAUKVsfXlu2LBhmj17tj777DPVqVPHev9QRESEQkNDFRERoUGDBiktLU316tVTeHi4HnnkESUnJ+vqq6+WJPXs2VOtW7fW/fffr0mTJik7O1vPPvushg0bZr189vDDD+vNN9/UU089pYEDB+qrr77SnDlztHDhQmuWtLQ09e/fX507d1aXLl00ZcoUFRQU6MEHH6z6b4yXOjuYFvWrZfc4AABUKVuj6e2335Yk3XDDDU7Xz5gxQwMGDJAkTZ48WTVq1NCdd96pwsJCpaam6q233rLWBgQEaMGCBRo6dKiSk5NVu3Zt9e/fX3/4wx+sNU2aNNHChQv1+OOP6/XXX9fll1+u9957T6mpqdaae+65R4cPH9bYsWOVnZ2tDh06aNGiRWXeHO6vzg2mOsEOu0cCAKBKedV5mnxZRc7zYCs3ztN03mDiPE0AAB/ns+dpgvfh
GSYAAM4gmnBeBBMAAP9DNKFcBBMAAM6IJpRBMAEAUBbRBCcEEwAA5SOaYCGYAAA4P6IJkggmAAAuhmgCwQQAgAuIJj9HMAEA4BqiyY8RTAAAuI5o8lMEEwAAFUM0+SGCCQCAiiOa/AzBBACAe4gmP0MwAQDgHqLJzxBMAAC4h2jyMwQTAADuIZr8DMEEAIB7iCYAAAAXEE0AAAAuIJoAAABcQDQBAAC4gGgCAABwAdEEAADgAqIJAADABUQTAACAC4gmAAAAFxBNAAAALiCa4JajhcbuEQAAqFJEEyrsaKFRr1nH7R4DAIAqRTShQkqDKTOn2O5RAACoUkQTXHZ2MKXfX9vucQAAqFJEE1xybjB1uSzA7pEAAKhSRBMuimACAIBowkUQTAAAnEE04bwIJgAA/odoQrkIJgAAnBFNKINgAgCgLKIJTggmAADKRzTBQjABAHB+RBMkEUwAAFwM0QSCCQAAFxBNfo5gAgDANUSTHyOYAABwHdHkpwgmAAAqhmjyQwQTAAAVRzT5GYIJAAD3EE1+hmACAMA9RJOfIZgAAHAP0eRnCCYAANxDNPkZggkAAPcQTQAAAC4gmgAAAFxQ0+4BAJeMj3DrZut/KtZNfy1Qm6gALepXS3WCHdL4PA8PBwDwBzzThGqr3GACAMBNRBOqJYIJAOBpRBOqHYIJAFAZiCZUKwQTAKCyEE2oNggmAEBlIppQLRBMAIDKRjTB5xFMAICqQDTBpxFMAICqQjTBLS+sKLR7BIIJAFClbI2mFStW6NZbb1VcXJwcDofmz5/vtH3AgAFyOBxOl169ejmt+eWXX9SvXz+Fh4crMjJSgwYN0rFjx5zWbNu2Tdddd51CQkIUHx+vSZMmlZnl448/VqtWrRQSEqK2bdvqiy++8PjxVhcvrCjUc0vtjSaCCQBQ1WyNpoKCArVv317Tpk0775pevXrp4MGD1uXvf/+70/Z+/fppx44dSk9P14IFC7RixQoNGTLE2p6fn6+ePXsqISFBmzZt0iuvvKLx48frnXfesdasXr1a9957rwYNGqTNmzerT58+6tOnjzIzMz1/0D6uNJgmdA+2bQaCCQBgB4cxxtg9hCQ5HA7NmzdPffr0sa4bMGCAcnNzyzwDVWrnzp1q3bq1NmzYoM6dO0uSFi1apFtuuUX//ve/FRcXp7ffflvPPPOMsrOzFRQUJEkaPXq05s+fr127dkmS7rnnHhUUFGjBggXWvq+++mp16NBB06dPL/e+CwsLVVj4v2db8vPzFR8fr7y8PIWHh1/Kt6JyufkZbpJzMD3bLbhqP8Ptv3N7JJj47DkAwH/l5+crIiLCpZ/fXv+epmXLlikqKkotW7bU0KFD9fPPP1vb1qxZo8jISCuYJCklJUU1atTQunXrrDXdunWzgkmSUlNTtXv3bh05csRak5KS4nS/qampWrNmzXnnmjhxoiIiIqxLfHy8R47XW5UJJhvwDBMAwE5eHU29evXS//3f/ykjI0N//OMftXz5ct18880qLi6WJGVnZysqKsrpNjVr1lS9evWUnZ1trYmOjnZaU/r1xdaUbi/PmDFjlJeXZ11+/PHHSztYL0YwAQAg1bR7gAvp27ev9ee2bduqXbt2atasmZYtW6YePXrYOJkUHBys4GD73tdTVbwhmCQRTAAA23n1M03natq0qRo0aKC9e/dKkmJiYpSTk+O05vTp0/rll18UExNjrTl06JDTmtKvL7amdLu/8pZgkkQwAQBs51PR9O9//1s///yzYmNjJUnJycnKzc3Vpk2brDVfffWVSkpKlJSUZK1ZsWKFioqKrDXp6elq2bKl6tata63JyMhwuq/09HQlJydX9iF5LW8KJkkEEwDAdrZG07Fjx7RlyxZt2bJFkpSVlaUtW7Zo//79OnbsmEaOHKm1a9fq+++/V0ZGhm677TY1b95cqampkqTExET16tVLgwcP1vr167Vq1SoNHz5cffv2VVxcnCTpvvvuU1BQkAYNGqQd
O3boo48+0uuvv660tDRrjscee0yLFi3Sq6++ql27dmn8+PHauHGjhg8fXuXfE2/gbcEkiWACANjO1mjauHGjOnbsqI4dO0qS0tLS1LFjR40dO1YBAQHatm2bfvOb36hFixYaNGiQOnXqpH/9619O7yWaNWuWWrVqpR49euiWW27Rtdde63QOpoiICC1ZskRZWVnq1KmTnnjiCY0dO9bpXE7XXHONZs+erXfeeUft27fXJ598ovnz56tNmzZV983wEt4YTAAAeAOvOU+Tr6vIeR5sdYHzNFU4mGw4T5Nn9sV5mgAAZ1Sr8zShavAMEwAAF0Y0gWACAMAFRJOfI5gAAHAN0eTHCCYAAFxHNPkpggkAgIohmvwQwQQAQMURTX6GYAIAwD1Ek58hmAAAcA/R5Gf8PZiOFnIuVwCAe9yKphtvvFG5ubllrs/Pz9eNN954qTOhEvl7MPWaddzuMQAAPsqtaFq2bJlOnTpV5vqTJ0/qX//61yUPBXhaaTBl5hTbPQoAwEfVrMjibdu2WX/+5ptvlJ2dbX1dXFysRYsW6bLLLvPcdIAHnB1M6ffXtnscAICPqlA0dejQQQ6HQw6Ho9yX4UJDQzV16lSPDQdcqnODqctlAXaPBADwURWKpqysLBlj1LRpU61fv14NGza0tgUFBSkqKkoBAfxQgncgmAAAnlShaEpISJAklZSUVMowgKcQTAAAT6tQNJ1tz549Wrp0qXJycspE1NixYy95MMBdBBMAoDK4FU3vvvuuhg4dqgYNGigmJkYOh8Pa5nA4iCbYhmACAFQWt6LphRde0IsvvqhRo0Z5eh7AbQQTAKAyuXWepiNHjuiuu+7y9CyA2wgmAEBlcyua7rrrLi1ZssTTswBuIZgAAFXBrZfnmjdvrueee05r165V27ZtFRgY6LT90Ucf9chwwMUQTACAquJWNL3zzjsKCwvT8uXLtXz5cqdtDoeDaPID638qVhebZyCYAABVya1oysrK8vQc8CHrfyrWTX8tUN679s1AMAEAqppb72mC/yoNpjZR9kUKwQQAsINbzzQNHDjwgtvff/99t4aBdzs7mBb1q2XLDAQTAMAubkXTkSNHnL4uKipSZmamcnNzy/0gX/i+c4OpTrDj4jfyMIIJAGAnt6Jp3rx5Za4rKSnR0KFD1axZs0seCt7FH4Op8eiFHtnP9y/39sh+AAD289h7mmrUqKG0tDRNnjzZU7uEF/CGYJLEM0wAANt59I3g+/bt0+nTpz25S9jIW4JJEsEEALCdWy/PpaWlOX1tjNHBgwe1cOFC9e/f3yODwV7eFEySCCYAgO3ciqbNmzc7fV2jRg01bNhQr7766kV/sw7ez9uCSRLBBACwnVvRtHTpUk/PAS/hjcEEAIA3cCuaSh0+fFi7d++WJLVs2VINGzb0yFCwB8EEAMD5ufVG8IKCAg0cOFCxsbHq1q2bunXrpri4OA0aNEjHjx/39IyoAgQTAAAX5lY0paWlafny5fr888+Vm5ur3NxcffbZZ1q+fLmeeOIJT8+ISkYwAQBwcW69PPfpp5/qk08+0Q033GBdd8sttyg0NFR333233n77bU/Nh0pGMAEA4Bq3nmk6fvy4oqOjy1wfFRXFy3M+hGACAMB1bkVTcnKyxo0bp5MnT1rXnThxQs8//7ySk5M9NhwqD8EEAEDFuPXy3JQpU9SrVy9dfvnlat++vSRp69atCg4O1pIlSzw6IDyPYAIAoOLciqa2bdtqz549mjVrlnbt2iVJuvfee9WvXz+FhoZ6dEB4FsEEAIB73IqmiRMnKjo6WoMHD3a6/v3339fhw4c1atQojwwHzyOYAABwj1vvafrzn/+sVq1albn+iiuu0PTp0y95KFQeggkAAPe4FU3Z2dmKjY0tc33Dhg118ODBSx4Klcffg+mFFYV2jwAA8FFuRVN8fLxWrVpV5vpVq1YpLi7ukodC5fH3YHpuKdEEAHCPW+9pGjx4sEaMGKGioiLdeOONkqSM
jAw99dRTnBEcXqk0mCZ0D7Z7FACAj3IrmkaOHKmff/5Zv//973Xq1ClJUkhIiEaNGqUxY8Z4dEDgUp0dTM92I5oAAO5xK5ocDof++Mc/6rnnntPOnTsVGhqqX/3qVwoO5gcSvAvBBADwFLeiqVRYWJiuuuoqT80CeBTBBADwJLfeCA54O4IJAOBpRBOqHYIJAFAZiCZUKwQTAKCyXNJ7moCLaTx6oUf2833IxdcQTACAysQzTagWCCYAQGUjmuDzCCYAQFUgmuCWo4XG7hEkEUwAgKpDNKHCjhYa9Zp13O4xCCYAQJUimlAhpcGUmVNs6xwEEwCgqhFNcNnZwZR+f23b5iCYAAB2IJrgknODqctlAbbMQTABAOxCNOGiCCYAAIgmXATBBADAGbZG04oVK3TrrbcqLi5ODodD8+fPd9pujNHYsWMVGxur0NBQpaSkaM+ePU5rfvnlF/Xr10/h4eGKjIzUoEGDdOzYMac127Zt03XXXaeQkBDFx8dr0qRJZWb5+OOP1apVK4WEhKht27b64osvPH68vsZbgkkSwQQAsJ2t0VRQUKD27dtr2rRp5W6fNGmS3njjDU2fPl3r1q1T7dq1lZqaqpMnT1pr+vXrpx07dig9PV0LFizQihUrNGTIEGt7fn6+evbsqYSEBG3atEmvvPKKxo8fr3feecdas3r1at17770aNGiQNm/erD59+qhPnz7KzMysvIP3ct4UTJIIJgCA7RzGGK84S6HD4dC8efPUp08fSWeeZYqLi9MTTzyhJ598UpKUl5en6OhozZw5U3379tXOnTvVunVrbdiwQZ07d5YkLVq0SLfccov+/e9/Ky4uTm+//baeeeYZZWdnKygoSJI0evRozZ8/X7t27ZIk3XPPPSooKNCCBQusea6++mp16NBB06dPd2n+/Px8RUREKC8vT+Hh4Z76tnje+IiLLnE5mMbnXXRfnvvsufs8sh9JVTv3y709sh8AQOWoyM9vr31PU1ZWlrKzs5WSkmJdFxERoaSkJK1Zs0aStGbNGkVGRlrBJEkpKSmqUaOG1q1bZ63p1q2bFUySlJqaqt27d+vIkSPWmrPvp3RN6f2Up7CwUPn5+U6X6sDbnmECAMBbeG00ZWdnS5Kio6Odro+Ojra2ZWdnKyoqyml7zZo1Va9ePac15e3j7Ps435rS7eWZOHGiIiIirEt8fHxFD9HrEEwAAJyf10aTtxszZozy8vKsy48//mj3SJeEYAIA4MK8NppiYmIkSYcOHXK6/tChQ9a2mJgY5eTkOG0/ffq0fvnlF6c15e3j7Ps435rS7eUJDg5WeHi408VXEUwAAFyc10ZTkyZNFBMTo4yMDOu6/Px8rVu3TsnJyZKk5ORk5ebmatOmTdaar776SiUlJUpKSrLWrFixQkVFRdaa9PR0tWzZUnXr1rXWnH0/pWtK76c6I5gAAHCNrdF07NgxbdmyRVu2bJF05s3fW7Zs0f79++VwODRixAi98MIL+sc//qHt27frgQceUFxcnPUbdomJierVq5cGDx6s9evXa9WqVRo+fLj69u2ruLg4SdJ9992noKAgDRo0SDt27NBHH32k119/XWlpadYcjz32mBYtWqRXX31Vu3bt0vjx47Vx40YNHz68qr8lVYpgAgDAdTXtvPONGzeqe/fu1telIdO/f3/NnDlTTz31lAoKCjRkyBDl5ubq2muv1aJFixQSEmLdZtasWRo+fLh69OihGjVq6M4779Qbb7xhbY+IiNCSJUs0bNgwderUSQ0aNNDYsWOdzuV0zTXXaPbs2Xr22Wf19NNP61e/+pXmz5+vNm3aVMF3wR4EEwAAFeM152nydb50niaPBRPnabooztMEAN6tWpynCZWDZ5gAAHAP0eRnCCYAANxDNPkZggkAAPcQTX7G34Np/U/Fdo8AAPBRRJOf8fdguumvBXaPAQDwUUQT/EJpMLWJ8t9oBABcGqIJ1d7ZwbSoXy27xwEA+CiiCdXaucFUJ9hh
90gAAB9l6xnBgcrkj8HESTkBoPLwTBOqJX8MJgBA5SKaUO0QTACAykA0oVohmAAAlYVoQrVBMAEAKhPRhGqBYAIAVDaiCT6PYAIAVAWiCT6NYAIAVBWiCW55YUWh3SMQTACAKkU0ocJeWFGo55baG00EEwCgqhFNqJDSYJrQPdi2GQgmAIAdiCa47OxgerabPdFEMAEA7EI0wSUEEwDA3xFNuCiCCQAAogkX4Q3BJIlgAgDYjmjCeXlLMEkimAAAtiOaUC5vCiZJBBMAwHZEE8rwtmCSRDABAGxHNMGJNwYTAADegGiChWACAOD8iCZIIpgAALgYogkEEwAALiCa/BzBBACAa4gmP0YwAQDgOqLJTxFMAABUDNHkhwgmAAAqjmjyMwQTAADuIZr8DMEEAIB7iCY/4+/BdLTQ2D0CAMBHEU1+xt+Dqdes43aPAQDwUUQT/EJpMGXmFNs9CgDARxFNqPbODqb0+2vbPQ4AwEcRTajWzg2mLpcF2D0SAMBHEU2otggmAIAnEU2olggmAICnEU2odggmAEBlIJpQrRBMAIDKQjSh2iCYAACViWhCtUAwAQAqG9EEn0cwAQCqAtEEn0YwAQCqCtEEt6z/yf6PIyGYAABViWhCha3/qVg3/bXA1hkIJgBAVSOaUCGlwdQmyr5IIZgAAHYgmuCys4NpUb9atsxAMAEA7EI0wSXnBlOdYEeVz0AwAQDsRDThoggmAACIJlyENwSTJIIJAGA7ognn5S3BJIlgAgDYjmhCubwpmCQRTAAA2xFNKMPbgkkSwQQAsB3RBCfeGEwAAHgDogkWggkAgPMjmiCJYAIA4GK8OprGjx8vh8PhdGnVqpW1/eTJkxo2bJjq16+vsLAw3XnnnTp06JDTPvbv36/evXurVq1aioqK0siRI3X69GmnNcuWLdOVV16p4OBgNW/eXDNnzqyKw/MaBBMAABfn1dEkSVdccYUOHjxoXVauXGlte/zxx/X555/r448/1vLly3XgwAHdcccd1vbi4mL17t1bp06d0urVq/XBBx9o5syZGjt2rLUmKytLvXv3Vvfu3bVlyxaNGDFCDz30kBYvXlylx2kXggkAANfUtHuAi6lZs6ZiYmLKXJ+Xl6e//OUvmj17tm688UZJ0owZM5SYmKi1a9fq6quv1pIlS/TNN9/on//8p6Kjo9WhQwdNmDBBo0aN0vjx4xUUFKTp06erSZMmevXVVyVJiYmJWrlypSZPnqzU1NQqPdaqRjABAOA6r3+mac+ePYqLi1PTpk3Vr18/7d+/X5K0adMmFRUVKSUlxVrbqlUrNWrUSGvWrJEkrVmzRm3btlV0dLS1JjU1Vfn5+dqxY4e15ux9lK4p3cf5FBYWKj8/3+niSwgmAAAqxqujKSkpSTNnztSiRYv09ttvKysrS9ddd52OHj2q7OxsBQUFKTIy0uk20dHRys7OliRlZ2c7BVPp9tJtF1qTn5+vEydOnHe2iRMnKiIiwrrEx8df6uFWGYIJAICK8+qX526++Wbrz+3atVNSUpISEhI0Z84chYaG2jiZNGbMGKWlpVlf5+fn+0Q4EUwAALjHq59pOldkZKRatGihvXv3KiYmRqdOnVJubq7TmkOHDlnvgYqJiSnz23SlX19sTXh4+AXDLDg4WOHh4U4XX0AwAQDgHp+KpmPHjmnfvn2KjY1Vp06dFBgYqIyMDGv77t27tX//fiUnJ0uSkpOTtX37duXk5Fhr0tPTFR4ertatW1trzt5H6ZrSfVQ3BBMAAO7x6mh68skntXz5cn3//fdavXq1br/9dgUEBOjee+9VRESEBg0apLS0NC1dulSbNm3Sgw8+qOTkZF199dWSpJ49e6p169a6//77tXXrVi1evFjPPvushg0bpuDgYEnSww8/rO+++05PPfWUdu3apbfeektz5szR448/buehVxp/D6YXVhTaPQIAwEd59Xua/v3vf+vee+/Vzz//rIYNG+raa6/V2rVr1bBhQ0nS
5MmTVaNGDd15550qLCxUamqq3nrrLev2AQEBWrBggYYOHark5GTVrl1b/fv31x/+8AdrTZMmTbRw4UI9/vjjev3113X55Zfrvffeq7anG/D3YHpuaaGetXsQAIBP8upo+vDDDy+4PSQkRNOmTdO0adPOuyYhIUFffPHFBfdzww03aPPmzW7NCN9QGkwTugfbPQoAwEd59ctzgCecHUzPdiOaAADuIZpQrRFMAABPIZpQbRFMAABPIppQLRFMAABPI5pQ7RBMAIDKQDShWiGYAACVhWhCtUEwAQAqE9GEaoFgAgBUNq8+uSX+p/HohR7Zz/chHtmNVyGYAABVgWea4JajhcbuESQRTACAqkM0ocKOFhr1mnXc7jEIJgBAleLlOVRIaTBl5hTbOgfBVL147OXnl3t7ZD8AUB6eaYLLzg6m9Ptr2zYHwQQAsAPRBJecG0xdLguwZQ6CCQBgF6IJF0UwAQBANOEiCCYAAM4gmnBe3hJMkggmAIDtiCaUy5uCSRLBBACwHdGEMrwtmCQRTAAA2xFNcOKNwQQAgDcgmmAhmAAAOD+iCZIIJgAALoZoAsEEAIALiCY/RzABAOAaosmPEUwAALiOaPJTBBMAABVDNPkhggkAgIojmvwMwQQAgHuIJj9DMAEA4B6iyc8QTAAAuIdo8jP+Hkzrfyq2ewQAgI8imvyMvwfTTX8tsHsMAICPIprgF0qDqU2U/0YjAODSEE2o9s4OpkX9atk9DgDARxFNqNbODaY6wQ67RwIA+CiiCdUWwQQA8CSiCdUSwQQA8DSiCdUOwQQAqAxEE6oVggkAUFlq2j0A4CleGUzjIzy4rzzP7QsAUGE804RqwSuDCQBQrRBN8HkEEwCgKhBN8GkEEwCgqhBNcMsLKwrtHoFgAgBUKaIJFfbCikI9t9TeaCKYAABVjWhChZQG04TuwbbNQDABAOxANMFlZwfTs93siSaCCQBgF87TBJcQTEBZjUcv9Mh+vn+5t0f2A6By8UwTLopgAgCAaMJFeEMwSSKYAAC2I5pwXt4STJIIJgCA7YgmlMubgkmS3weTN5wXCwD8HdGEMrwtmCT5fTDZfV4sAADRhHN4YzD5M284LxYA4AyiCRaCybvweACAdyGaIIkf0N6GxwMAvA/RBH5AexkeDwDwTkSTn+MHtHfh8QAA70U0+TF+QHsXHg8A8G5Ek5/iB7R34fEAAO9HNPkhfkB7Fx4PAPANRJOf4Qe0d+HxAADfUdPuAbzNtGnT9Morryg7O1vt27fX1KlT1aVLF7vH8hh+QHsPggl2aTx6oUf28/3LvT2yH8BX8EzTWT766COlpaVp3Lhx+vrrr9W+fXulpqYqJyfH7tE8xt9/QB8tNHaPIIlgAgBfxDNNZ3nttdc0ePBgPfjgg5Kk6dOna+HChXr//fc1evRom6fzDH/+AX200KjXrONaNdHeOXwimMZHeHBfeZ7bFwDYiGj6r1OnTmnTpk0aM2aMdV2NGjWUkpKiNWvWlFlfWFiowsL/fYhqXt6ZHwz5+fmVMl9J4XGP7Cff4cFnWlw4Vm+Z+2ih0R1zjmvn4RKXHqPKmnvSqkK9+K9Teua6ID2aFKT8ijzz5UPfb+edVeHclfTf3/kwd9XO3WbcYo/sJ/P5VI/sx1W+Ore/KP17bIwL/+4ZGGOM+emnn4wks3r1aqfrR44cabp06VJm/bhx44wkLly4cOHChUs1uPz4448XbQWeaXLTmDFjlJaWZn1dUlKiX375RfXr15fD4XB7v/n5+YqPj9ePP/6o8PBwT4zqVTg+31bdj0+q/sfI8fk2js/zjDE6evSo4uLiLrqWaPqvBg0aKCAgQIcOHXK6/tChQ4qJiSmzPjg4WMHBzu9HiYyM9Ng84eHh1fI/iFIcn2+r7scnVf9j5Ph8G8fnWRERES6t47fn/isoKEidOnVSRkaGdV1JSYkyMjKUnJxs42QAAMAb8EzTWdLS0tS/f3917txZXbp00ZQpU1RQUGD9
Nh0AAPBfRNNZ7rnnHh0+fFhjx45Vdna2OnTooEWLFik6OrrKZggODta4cePKvPRXXXB8vq26H59U/Y+R4/NtHJ+9HMa48jt2AAAA/o33NAEAALiAaAIAAHAB0QQAAOACogkAAMAFRJOXmTZtmho3bqyQkBAlJSVp/fr1do/klhUrVujWW29VXFycHA6H5s+f77TdGKOxY8cqNjZWoaGhSklJ0Z49e+wZ1g0TJ07UVVddpTp16igqKkp9+vTR7t27ndacPHlSw4YNU/369RUWFqY777yzzMlTvdXbb7+tdu3aWSeYS05O1pdffmlt9+VjO9fLL78sh8OhESNGWNf5+vGNHz9eDofD6dKqVStru68fnyT99NNP+n//7/+pfv36Cg0NVdu2bbVx40Zruy//G9O4ceMyj5/D4dCwYcMk+f7jV1xcrOeee05NmjRRaGiomjVrpgkTJjh99pvXPn6X/qlt8JQPP/zQBAUFmffff9/s2LHDDB482ERGRppDhw7ZPVqFffHFF+aZZ54xc+fONZLMvHnznLa//PLLJiIiwsyfP99s3brV/OY3vzFNmjQxJ06csGfgCkpNTTUzZswwmZmZZsuWLeaWW24xjRo1MseOHbPWPPzwwyY+Pt5kZGSYjRs3mquvvtpcc801Nk7tun/84x9m4cKF5ttvvzW7d+82Tz/9tAkMDDSZmZnGGN8+trOtX7/eNG7c2LRr18489thj1vW+fnzjxo0zV1xxhTl48KB1OXz4sLXd14/vl19+MQkJCWbAgAFm3bp15rvvvjOLFy82e/futdb48r8xOTk5To9denq6kWSWLl1qjPH9x+/FF1809evXNwsWLDBZWVnm448/NmFhYeb111+31njr40c0eZEuXbqYYcOGWV8XFxebuLg4M3HiRBununTnRlNJSYmJiYkxr7zyinVdbm6uCQ4ONn//+99tmPDS5eTkGElm+fLlxpgzxxMYGGg+/vhja83OnTuNJLNmzRq7xrwkdevWNe+99161ObajR4+aX/3qVyY9Pd1cf/31VjRVh+MbN26cad++fbnbqsPxjRo1ylx77bXn3V7d/o157LHHTLNmzUxJSUm1ePx69+5tBg4c6HTdHXfcYfr162eM8e7Hj5fnvMSpU6e0adMmpaSkWNfVqFFDKSkpWrNmjY2TeV5WVpays7OdjjUiIkJJSUk+e6x5eXmSpHr16kmSNm3apKKiIqdjbNWqlRo1auRzx1hcXKwPP/xQBQUFSk5OrjbHNmzYMPXu3dvpOKTq89jt2bNHcXFxatq0qfr166f9+/dLqh7H949//EOdO3fWXXfdpaioKHXs2FHvvvuutb06/Rtz6tQp/e1vf9PAgQPlcDiqxeN3zTXXKCMjQ99++60kaevWrVq5cqVuvvlmSd79+HFGcC/xn//8R8XFxWXOPh4dHa1du3bZNFXlyM7OlqRyj7V0my8pKSnRiBEj1LVrV7Vp00bSmWMMCgoq8yHOvnSM27dvV3Jysk6ePKmwsDDNmzdPrVu31pYtW3z+2D788EN9/fXX2rBhQ5lt1eGxS0pK0syZM9WyZUsdPHhQzz//vK677jplZmZWi+P77rvv9PbbbystLU1PP/20NmzYoEcffVRBQUHq379/tfo3Zv78+crNzdWAAQMkVY+/n6NHj1Z+fr5atWqlgIAAFRcX68UXX1S/fv0keffPCKIJuETDhg1TZmamVq5cafcoHtWyZUtt2bJFeXl5+uSTT9S/f38tX77c7rEu2Y8//qjHHntM6enpCgkJsXucSlH6/9glqV27dkpKSlJCQoLmzJmj0NBQGyfzjJKSEnXu3FkvvfSSJKljx47KzMzU9OnT1b9/f5un86y//OUvuvnmmxUXF2f3KB4zZ84czZo1S7Nnz9YVV1yhLVu2aMSIEYqLi/P6x4+X57xEgwYNFBAQUOY3IA4dOqSYmBibpqocpcdTHY51+PDhWrBggZYuXarLL7/cuj4mJkanTp1Sbm6u
03pfOsagoCA1b95cnTp10sSJE9W+fXu9/vrrPn9smzZtUk5Ojq688krVrFlTNWvW1PLly/XGG2+oZs2aio6O9unjK09kZKRatGihvXv3+vzjJ0mxsbFq3bq103WJiYnWS5DV5d+YH374Qf/85z/10EMPWddVh8dv5MiRGj16tPr27au2bdvq/vvv1+OPP66JEydK8u7Hj2jyEkFBQerUqZMyMjKs60pKSpSRkaHk5GQbJ/O8Jk2aKCYmxulY8/PztW7dOp85VmOMhg8frnnz5umrr75SkyZNnLZ36tRJgYGBTse4e/du7d+/32eO8VwlJSUqLCz0+WPr0aOHtm/fri1btliXzp07q1+/ftafffn4ynPs2DHt27dPsbGxPv/4SVLXrl3LnOLj22+/VUJCgqTq8W+MJM2YMUNRUVHq3bu3dV11ePyOHz+uGjWc8yMgIEAlJSWSvPzxs/Vt6HDy4YcfmuDgYDNz5kzzzTffmCFDhpjIyEiTnZ1t92gVdvToUbN582azefNmI8m89tprZvPmzeaHH34wxpz5ddLIyEjz2WefmW3btpnbbrvNK36d1FVDhw41ERERZtmyZU6/Gnz8+HFrzcMPP2waNWpkvvrqK7Nx40aTnJxskpOTbZzadaNHjzbLly83WVlZZtu2bWb06NHG4XCYJUuWGGN8+9jKc/Zvzxnj+8f3xBNPmGXLlpmsrCyzatUqk5KSYho0aGBycnKMMb5/fOvXrzc1a9Y0L774otmzZ4+ZNWuWqVWrlvnb3/5mrfH1f2OKi4tNo0aNzKhRo8ps8/XHr3///uayyy6zTjkwd+5c06BBA/PUU09Za7z18SOavMzUqVNNo0aNTFBQkOnSpYtZu3at3SO5ZenSpUZSmUv//v2NMWd+pfS5554z0dHRJjg42PTo0cPs3r3b3qEroLxjk2RmzJhhrTlx4oT5/e9/b+rWrWtq1aplbr/9dnPw4EH7hq6AgQMHmoSEBBMUFGQaNmxoevToYQWTMb59bOU5N5p8/fjuueceExsba4KCgsxll11m7rnnHqdzGPn68RljzOeff27atGljgoODTatWrcw777zjtN3X/41ZvHixkVTuzL7++OXn55vHHnvMNGrUyISEhJimTZuaZ555xhQWFlprvPXxcxhz1ik4AQAAUC7e0wQAAOACogkAAMAFRBMAAIALiCYAAAAXEE0AAAAuIJoAAABcQDQBAAC4gGgCAABwAdEEoMoMGDBAffr0sXsMAHAL0QT4ucOHDysoKEgFBQUqKipS7dq1rU+LPx/ixzvNnDlTkZGRdo8BVFtEE+Dn1qxZo/bt26t27dr6+uuvVa9ePTVq1MjusXzaqVOn7B4BQCUgmgA/t3r1anXt2lWStHLlSuvP5zN+/Hh98MEH+uyzz+RwOORwOLRs2TJJ0vbt23XjjTcqNDRU9evX15AhQ3Ts2LHz7mvDhg1q2LCh/vjHP0qScnNz9dBDD6lhw4YKDw/XjTfeqK1btzrdd4cOHfTXv/5VjRs3VkREhPr27aujR49aaz755BO1bdvWmiElJUUFBQXl3v+yZcvkcDi0cOFCtWvXTiEhIbr66quVmZnptG7lypW67rrrFBoaqvj4eD366KNO+2zcuLEmTJigBx54QOHh4RoyZEi593ex2d577z0lJiYqJCRErVq10ltvvWVt+/777+VwODR37lx1795dtWrVUvv27bVmzRrrWB588EHl5eVZj8v48eMlSYWFhXryySd12WWXqXbt2kpKSrIeM+l/z1AtXrxYiYmJCgsLU69evXTw4EGn+d9//31dccUVCg4OVmxsrIYPH25tu9hjB1QLdn9iMICq98MPP5iIiAgTERFhAgMDTUhIiImIiDBBQUEmODjYREREmKFDh5Z726NHj5q7777b9OrVyxw8eNAcPHjQFBYWmmPHjpnY2Fhzxx13mO3bt5uMjAzTpEkT079/f+u2/fv3N7fddpsxxpiMjAwTERFh/vznP1vb
U1JSzK233mo2bNhgvv32W/PEE0+Y+vXrm59//tkYY8y4ceNMWFiYdR8rVqwwMTEx5umnnzbGGHPgwAFTs2ZN89prr5msrCyzbds2M23aNHP06NFyj2Xp0qVGkklMTDRLliwx27ZtM7/+9a9N48aNzalTp4wxxuzdu9fUrl3bTJ482Xz77bdm1apVpmPHjmbAgAHWfhISEkx4eLj505/+ZPbu3Wv27t1b5r4uNtvf/vY3Exsbaz799FPz3XffmU8//dTUq1fPzJw50xhjTFZWlpFkWrVqZRYsWGB2795tfvvb35qEhARTVFRkCgsLzZQpU0x4eLj1uJTu+6GHHjLXXHONWbFihdm7d6955ZVXTHBwsPn222+NMcbMmDHDBAYGmpSUFLNhwwazadMmk5iYaO677z5r/rfeesuEhISYKVOmmN27d5v169ebyZMnu/zYAdUB0QT4oaKiIpOVlWW2bt1qAgMDzdatW83evXtNWFiYWb58ucnKyjKHDx8+7+3Pjp9S77zzjqlbt645duyYdd3ChQtNjRo1THZ2ttPt5s6da8LCwsyHH35orf3Xv/5lwsPDzcmTJ53226xZMyusxo0bZ2rVqmXy8/Ot7SNHjjRJSUnGGGM2bdpkJJnvv//epe9DaTSdPcfPP/9sQkNDzUcffWSMMWbQoEFmyJAhTrf717/+ZWrUqGFOnDhhjDkTTX369LngfV1stmbNmpnZs2c7XTdhwgSTnJxsjPlfNL333nvW9h07dhhJZufOncaYM/ETERHhtI8ffvjBBAQEmJ9++snp+h49epgxY8ZYt5PkFHvTpk0z0dHR1tdxcXHmmWeeKXd2Vx47oDqoad9zXADsUrNmTTVu3Fhz5szRVVddpXbt2mnVqlWKjo5Wt27d3Nrnzp07rfdGleratatKSkq0e/duRUdHS5LWrVunBQsW6JNPPnF6M/nWrVt17Ngx1a9f32m/J06c0L59+6yvGzdurDp16lhfx8bGKicnR5LUvn179ejRQ23btlVqaqp69uyp3/72t6pbt+4FZ09OTrb+XK9ePbVs2VI7d+605tq2bZtmzZplrTHGqKSkRFlZWUpMTJQkde7c+YL3caHZCgoKtG/fPg0aNEiDBw+2bnP69GlFREQ47addu3ZOxy5JOTk5atWqVbn3u337dhUXF6tFixZO1xcWFjp9r2vVqqVmzZo57bv0+5qTk6MDBw6oR48e5d6Hq48d4OuIJsAPXXHFFfrhhx9UVFSkkpIShYWF6fTp0zp9+rTCwsKUkJCgHTt2VMp9N2vWTPXr19f777+v3r17KzAwUJJ07NgxxcbGOr3XptTZvxFWur6Uw+FQSUmJJCkgIEDp6elavXq1lixZoqlTp+qZZ57RunXr1KRJE7fmPXbsmH73u9/p0UcfLbPt7DfMnx2L5bnQbLVq1ZIkvfvuu0pKSipzu7OdffwOh0OSrOM/3/wBAQHatGlTmX2FhYWVu9/SfRtjJEmhoaEXPDZXHzvA1xFNgB/64osvVFRUpB49emjSpEnq1KmT+vbtqwEDBqhXr15lfoCeKygoSMXFxU7XJSYmaubMmSooKLACYtWqVapRo4ZatmxprWvQoIHmzp2rG264QXfffbfmzJmjwMBAXXnllcrOzraeBXOXw+FQ165d1bVrV40dO1YJCQmaN2+e0tLSznubtWvXWgF05MgRffvtt9YzSFdeeaW++eYbNW/e3O2ZXJktLi5O3333nfr16+f2/st7XDp27Kji4mLl5OTouuuuc2u/derUUePGjZWRkaHu3buX2e6pxw7wdvz2HOCHEhISFBYWpkOHDum2225TfHy8duzYoTvvvFPNmzdXQkLCBW/fuHFjbdu2Tbt379Z//vMfFRUVqV+/fgoJCVH//v2VmZmppUuX6pFHHtH9999vvTRXKioqSl999ZV27dqle++9V6dPn1ZKSoqSk5PVp08fLVmyRN9//71Wr16tZ555Rhs3bnTpuNatW6eXXnpJ
Gzdu1P79+zV37lwdPnzYCqDz+cMf/qCMjAxlZmZqwIABatCggfXS4ahRo7R69WoNHz5cW7Zs0Z49e/TZZ585/eaYJ2Z7/vnnNXHiRL3xxhv69ttvtX37ds2YMUOvvfaay/fRuHFjHTt2TBkZGfrPf/6j48ePq0WLFurXr58eeOABzZ07V1lZWVq/fr0mTpyohQsXurzv8ePH69VXX9Ubb7yhPXv26Ouvv9bUqVMlySOPHeAT7H5TFQB7/P3vfzfXXnutMcaYFStWmObNm7t825ycHHPTTTeZsLAwI8ksXbrUGGPMtm3bTPfu3U1ISIipV6+eGTx4sNNvrp37BvIDBw6YFi1amLvvvtucPn3a5Ofnm0ceecTExcWZwMBAEx8fb/r162f2799vjDnzRvD27ds7zTJ58mSTkJBgjDHmm2++MampqaZhw4YmODjYtGjRwkydOvW8x1H6RvDPP//cXHHFFSYoKMh06dLFbN261Wnd+vXrreOtXbu2adeunXnxxRet7QkJCU6/SVYeV2abNWuW6dChgwkKCjJ169Y13bp1M3PnzjXG/O+N4Js3b7bWHzlyxOn7b4wxDz/8sKlfv76RZMaNG2eMMebUqVNm7NixpnHjxiYwMNDExsaa22+/3Wzbts0YU/4byOfNm2fO/RExffp007JlS2sfjzzyiLXtYo8dUB04jPnvi9YA4GeWLVum7t2768iRI7z3BsBF8fIcAACAC4gmAAAAF/DyHAAAgAt4pgkAAMAFRBMAAIALiCYAAAAXEE0AAAAuIJoAAABcQDQBAAC4gGgCAABwAdEEAADggv8Pma9YkwPggi4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_list_len_pair_hist(\n",
    "['origin', 'subsampled'], '# tokens per sentence', 'count', sentences, subsampled);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e2596307-9df2-460f-b6ab-8002193fa3d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\"the\"的数量：之前=50770,之后=1970'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compare_counts(token):\n",
    "    return (f'\"{token}\"的数量：' f'之前={sum([l.count(token) for l in sentences])},' f'之后={sum([l.count(token) for l in subsampled])}')\n",
    "compare_counts('the')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e45d7154-1cc5-4ecd-9c22-e03e15d14c53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\"join\"的数量：之前=45,之后=45'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compare_counts('join')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3efe3b25-f349-4877-8215-7dd9681ad290",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\"a\"的数量：之前=21196,之后=1300'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compare_counts('a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9172bb05-3175-4584-ba3a-56e95442dd42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[], [392, 2115], [140, 5277, 3054, 1580]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "corpus = [vocab[line] for line in subsampled]\n",
    "corpus[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8a1dd5f3-9695-4fb8-a5b9-bcad8c95e764",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据集 [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n",
      "中⼼词 0 的上下⽂词是 [1, 2]\n",
      "中⼼词 1 的上下⽂词是 [0, 2, 3]\n",
      "中⼼词 2 的上下⽂词是 [0, 1, 3, 4]\n",
      "中⼼词 3 的上下⽂词是 [2, 4]\n",
      "中⼼词 4 的上下⽂词是 [3, 5]\n",
      "中⼼词 5 的上下⽂词是 [3, 4, 6]\n",
      "中⼼词 6 的上下⽂词是 [4, 5]\n",
      "中⼼词 7 的上下⽂词是 [8, 9]\n",
      "中⼼词 8 的上下⽂词是 [7, 9]\n",
      "中⼼词 9 的上下⽂词是 [7, 8]\n"
     ]
    }
   ],
   "source": [
    "tiny_dataset = [list(range(7)), list(range(7, 10))]\n",
    "print('数据集', tiny_dataset)\n",
    "for center, context in zip(*d2l.get_centers_and_contexts(tiny_dataset, 2)):\n",
    "    print('中⼼词', center, '的上下⽂词是', context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "56f0b5f0-f166-4f79-8cd2-af83b67f6771",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# “中⼼词-上下⽂词对”的数量: 1501125'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_centers, all_contexts = d2l.get_centers_and_contexts(corpus, 5) \n",
    "f'# “中⼼词-上下⽂词对”的数量: {sum([len(contexts) for contexts in all_contexts])}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3f1cca2-6f66-429f-b9eb-45df348ccc87",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7550dd57-5d4c-40e4-9540-4b21b811aa6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({3: 4422, 2: 3341, 1: 2237})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator = d2l.RandomGenerator([2, 3, 4])\n",
    "Counter([generator.draw() for _ in range(10000)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1019488e-a06c-48cd-9856-915ea5cc6dac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<unk>'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.to_tokens(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "056a11a6-f390-42b5-95ac-802479a7c077",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'the'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.to_tokens(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7ad5582c-ec6b-4245-93f2-5da90610beef",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_negatives = d2l.get_negatives(all_contexts, vocab, counter, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a4b0c4e6-6920-4c3a-a515-99b74b9177a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "centers = tensor([[1],\n",
      "        [1]]) torch.Size([2, 1])\n",
      "contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],\n",
      "        [2, 2, 2, 3, 3, 0]]) torch.Size([2, 6])\n",
      "masks = tensor([[1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 0]]) torch.Size([2, 6])\n",
      "labels = tensor([[1, 1, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 0, 0, 0]]) torch.Size([2, 6])\n"
     ]
    }
   ],
   "source": [
    "x_1 = (1, [2, 2], [3, 3, 3, 3])\n",
    "x_2 = (1, [2, 2, 2], [3, 3])\n",
    "batch =  d2l.batchify((x_1, x_2))\n",
    "\n",
    "names = ['centers', 'contexts_negatives', 'masks', 'labels']\n",
    "for name, data in zip(names, batch):\n",
    "    print(name, '=', data,data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0c97f8c0-e3fb-4297-a629-1c5a91d3c8f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "centers shape: torch.Size([512, 1])\n",
      "contexts_negatives shape: torch.Size([512, 60])\n",
      "masks shape: torch.Size([512, 60])\n",
      "labels shape: torch.Size([512, 60])\n"
     ]
    }
   ],
   "source": [
    "data_iter, vocab = d2l.load_data_ptb(512, 5, 5)\n",
    "for batch in data_iter:\n",
    "    for name, data in zip(names, batch):\n",
    "        print(name, 'shape:', data.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bad48d28-79ac-4a47-984b-c795717c0abf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.2.2+cu118'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cf0aac29-48b3-48bc-bf3f-756a31e3f5a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)\n"
     ]
    }
   ],
   "source": [
    "embed = torch.nn.Embedding(num_embeddings=20, embedding_dim=4)\n",
    "print(f'Parameter embedding_weight ({embed.weight.shape}, ' f'dtype={embed.weight.dtype})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "39b9358a-b11e-4ec0-b6f0-5e93281be871",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.0870, -2.4230,  0.1769, -0.6028],\n",
       "        [-1.0941, -0.3099, -0.5415, -1.4478],\n",
       "        [ 1.3356,  0.6281, -0.1886,  0.3115],\n",
       "        [-0.2177, -0.0104, -1.8339, -1.0206],\n",
       "        [-0.3971, -0.3391,  0.4153,  2.2809],\n",
       "        [ 0.1145, -0.4256, -0.6927, -0.2498],\n",
       "        [-1.3401, -0.6374, -1.3471,  1.8846],\n",
       "        [-1.4583, -1.4211, -0.7252,  0.6376],\n",
       "        [ 1.1957,  1.7431,  2.1507,  0.6965],\n",
       "        [ 0.9710, -1.2405, -0.6184,  0.4349],\n",
       "        [ 1.0757, -0.2949, -0.1132, -3.7262],\n",
       "        [-0.7974, -1.2629,  0.8216,  1.5593],\n",
       "        [ 0.3920,  0.9417, -0.2181,  0.3756],\n",
       "        [ 0.4057,  0.4345, -0.6123, -0.0419],\n",
       "        [-0.2562, -1.2691,  0.0688, -0.4800],\n",
       "        [-1.9576,  0.2081,  1.1593, -0.1719],\n",
       "        [-0.7271,  0.7987, -0.7920, -1.8672],\n",
       "        [ 1.0119, -0.1565, -1.5041, -0.7412],\n",
       "        [ 0.7082, -1.1592, -0.8617, -1.2090],\n",
       "        [ 0.2953, -0.9249, -1.0332, -0.0245]], requires_grad=True)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embed.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2e61b435-32d5-46b8-acd7-fe71da806393",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.0941, -0.3099, -0.5415, -1.4478],\n",
       "         [ 1.3356,  0.6281, -0.1886,  0.3115],\n",
       "         [-0.2177, -0.0104, -1.8339, -1.0206]],\n",
       "\n",
       "        [[-0.3971, -0.3391,  0.4153,  2.2809],\n",
       "         [ 0.1145, -0.4256, -0.6927, -0.2498],\n",
       "         [-1.3401, -0.6374, -1.3471,  1.8846]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
    "embed(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bdf1e9e9-c8be-4e38-94f6-ed88b9a41f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n",
    "    v = embed_v(center)\n",
    "    u = embed_u(contexts_and_negatives)\n",
    "    pred = torch.bmm(v, u.permute(0, 2, 1))\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e8394dd1-4a8c-42f4-89c4-ec5ff86d34e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入张量形状: torch.Size([2, 3, 4])\n",
      "第二个张量形状: torch.Size([2, 4, 5])\n",
      "输出张量形状: torch.Size([2, 3, 5])\n",
      "输出张量内容:\n",
      " tensor([[[ 110.,  120.,  130.,  140.,  150.],\n",
      "         [ 246.,  272.,  298.,  324.,  350.],\n",
      "         [ 382.,  424.,  466.,  508.,  550.]],\n",
      "\n",
      "        [[1678., 1736., 1794., 1852., 1910.],\n",
      "         [2134., 2208., 2282., 2356., 2430.],\n",
      "         [2590., 2680., 2770., 2860., 2950.]]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# 创建输入张量\n",
    "batch_size = 2\n",
    "n = 3\n",
    "m = 4\n",
    "p = 5\n",
    "\n",
    "# 初始化两个张量\n",
    "input_tensor = torch.tensor([\n",
    "    [[1, 2, 3, 4],\n",
    "     [5, 6, 7, 8],\n",
    "     [9, 10, 11, 12]],\n",
    "    [[13, 14, 15, 16],\n",
    "     [17, 18, 19, 20],\n",
    "     [21, 22, 23, 24]]\n",
    "], dtype=torch.float32)\n",
    "\n",
    "mat2_tensor = torch.tensor([\n",
    "    [[1, 2, 3, 4, 5],\n",
    "     [6, 7, 8, 9, 10],\n",
    "     [11, 12, 13, 14, 15],\n",
    "     [16, 17, 18, 19, 20]],\n",
    "    [[21, 22, 23, 24, 25],\n",
    "     [26, 27, 28, 29, 30],\n",
    "     [31, 32, 33, 34, 35],\n",
    "     [36, 37, 38, 39, 40]]\n",
    "], dtype=torch.float32)\n",
    "\n",
    "# 执行批量矩阵乘法\n",
    "output = torch.bmm(input_tensor, mat2_tensor)\n",
    "\n",
    "print(\"输入张量形状:\", input_tensor.shape)\n",
    "print(\"第二个张量形状:\", mat2_tensor.shape)\n",
    "print(\"输出张量形状:\", output.shape)\n",
    "print(\"输出张量内容:\\n\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "f5cef831-fddc-4b60-962d-b9dc600baced",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1678"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "13*21 + 14*26 + 15*31 + 16*36"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8574eab8-20fa-4a60-b384-18c1f58a7511",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 4])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "skip_gram(torch.ones((2, 1), dtype=torch.long),\n",
    "torch.ones((2, 4), dtype=torch.long), embed, embed).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6aa949f-ebd7-4a1e-af5c-e4e5ec6558d8",
   "metadata": {},
   "source": [
    "embedding层的权重矩阵目前还只是一个随机数组成的矩阵，我们最终目标就是通过训练使得这个权重矩阵可用与表示词的向量映射。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "de090de1-602d-4025-9010-cac7923b8b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SigmoidBCELoss(torch.nn.Module):\n",
    "    # 带掩码的⼆元交叉熵损失\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, inputs, target, mask=None):\n",
    "        out = torch.nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction=\"none\")\n",
    "        return out.mean(dim=1)\n",
    "\n",
    "loss = SigmoidBCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a530d960-e6c9-4bff-8321-02fd98905d28",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9352, 1.8462])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)\n",
    "label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])\n",
    "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])\n",
    "loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "915f683b-6b3d-4805-b1b7-5919d6714611",
   "metadata": {},
   "source": [
    "因为原先是除以dim=1真个的均值，会缩小损失值，所以先还原，再除以实际计算了损失值的个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9ec2406d-0602-4ecc-84e9-72bf65022d01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 2.])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask.shape[1] / mask.sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1454ae1-de5f-43c2-ba68-721d91fb6b15",
   "metadata": {},
   "source": [
    "负采样修改了原⽬标函数。给定中⼼词wc的上下⽂窗⼝，任意上下⽂词wo来⾃该上下⽂窗⼝的被认为是由下\n",
    "式建模概率的事件：$$P(D = 1 | w_c, w_o) = \\sigma(u_o^{\\top} v_c)$$其中σ使⽤了sigmoid激活函数的定义：$$\\sigma(x)=\\frac{1}{1 + e^{-x}}$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8cfac73-2703-4da7-a58b-6c2dfcff06b4",
   "metadata": {},
   "source": [
    "以下是对上述内容进行整理后的结果，包含公式的规范 Markdown 表示以及对每一步的简要解释。\n",
    "\n",
    "**1. 负采样对联合概率的重写**\n",
    "\n",
    "原始联合概率（仅涉及正例）被重写为：\n",
    "$$\n",
    "\\prod_{t = 1}^{T} \\prod_{\\substack{-m\\leq j\\leq m\\\\ j\\neq 0}} P(w^{(t + j)} | w^{(t)}) \\quad (14.2.4)\n",
    "$$\n",
    "这里 $T$ 表示文本序列的总时间步数，$m$ 是上下文窗口的大小，$w^{(t)}$ 表示时间步 $t$ 处的词。\n",
    "\n",
    "**2. 通过事件近似条件概率**\n",
    "\n",
    "通过事件 $S, N_1, \\cdots, N_K$ 近似条件概率：\n",
    "$$\n",
    "P(w^{(t + j)} | w^{(t)}) = P(D = 1 | w^{(t)}, w^{(t + j)}) \\prod_{k = 1, w_k\\sim P(w)}^{K} P(D = 0 | w^{(t)}, w_k) \\quad (14.2.5)\n",
    "$$\n",
    "其中 $D$ 是一个二元变量，$D = 1$ 表示正例（$w^{(t + j)}$ 是 $w^{(t)}$ 的真实上下文词），$D = 0$ 表示负例（$w_k$ 是从噪声分布 $P(w)$ 中采样得到的噪声词），$K$ 是负采样的数量。\n",
    "\n",
    "**3. 条件概率的对数损失推导**\n",
    "\n",
    "设 $i_t$ 和 $h_k$ 分别表示词 $w^{(t)}$ 和噪声词 $w_k$ 在文本序列的时间步 $t$ 处的索引。对 $(14.2.5)$ 中条件概率取对数损失：\n",
    "\n",
    "**第一步：展开对数损失**\n",
    "\n",
    "$$\n",
    "-\\log P(w^{(t + j)} | w^{(t)}) = -\\log P(D = 1 | w^{(t)}, w^{(t + j)}) - \\sum_{k = 1, w_k\\sim P(w)}^{K} \\log P(D = 0 | w^{(t)}, w_k)\n",
    "$$\n",
    "\n",
    "**第二步：代入 sigmoid 形式**\n",
    "\n",
    "已知 $P(D = 1 | w^{(t)}, w^{(t + j)})=\\sigma(u_{i_{t + j}}^{\\top} v_{i_t})$，$P(D = 0 | w^{(t)}, w_k)=1 - \\sigma(u_{h_k}^{\\top} v_{i_t})$，代入上式可得：\n",
    "$$\n",
    "-\\log P(w^{(t + j)} | w^{(t)}) = -\\log \\sigma (u_{i_{t + j}}^{\\top} v_{i_t}) - \\sum_{k = 1, w_k\\sim P(w)}^{K} \\log (1 - \\sigma (u_{h_k}^{\\top} v_{i_t}))\n",
    "$$\n",
    "\n",
    "**第三步：利用 sigmoid 函数性质 $1 - \\sigma(x)=\\sigma(-x)$ 进行化简**\n",
    "$$\n",
    "-\\log P(w^{(t + j)} | w^{(t)}) = -\\log \\sigma (u_{i_{t + j}}^{\\top} v_{i_t}) - \\sum_{k = 1, w_k\\sim P(w)}^{K} \\log \\sigma (-u_{h_k}^{\\top} v_{i_t})\n",
    "$$\n",
    "\n",
    "#### **总结**\n",
    "经过整理，负采样下条件概率的对数损失公式为：\n",
    "$$\n",
    "-\\log P(w^{(t + j)} | w^{(t)}) = -\\log \\sigma (u_{i_{t + j}}^{\\top} v_{i_t}) - \\sum_{k = 1, w_k\\sim P(w)}^{K} \\log \\sigma (-u_{h_k}^{\\top} v_{i_t})\n",
    "$$\n",
    "这个公式用于在负采样的词嵌入模型训练中计算损失，通过最小化这个损失来学习词的嵌入向量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d9a69ca0-2b2f-467e-b323-fe37bbc229b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9352\n",
      "1.8462\n"
     ]
    }
   ],
   "source": [
    "def sigmd(x):\n",
    "    return -math.log(1 / (1 + math.exp(-x)))\n",
    "print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')\n",
    "print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "95b5983e-02e2-408b-970f-bef9da13bf6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_size = 100\n",
    "net = torch.nn.Sequential(torch.nn.Embedding(num_embeddings=len(vocab),\n",
    "embedding_dim=embed_size),\n",
    "torch.nn.Embedding(num_embeddings=len(vocab),\n",
    "embedding_dim=embed_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "b5e5dcc4-e1dd-4cfb-9106-2f9cacd31596",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):\n",
    "    def init_weights(m):\n",
    "        if type(m) == torch.nn.Embedding:\n",
    "            torch.nn.init.xavier_uniform_(m.weight)\n",
    "    net.apply(init_weights)\n",
    "    net = net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[1, num_epochs])\n",
    "    # 规范化的损失之和，规范化的损失数\n",
    "    metric = d2l.Accumulator(2)\n",
    "    for epoch in range(num_epochs):\n",
    "        timer, num_batches = d2l.Timer(), len(data_iter)\n",
    "        for i, batch in enumerate(data_iter):\n",
    "            optimizer.zero_grad()\n",
    "            center, context_negative, mask, label = [\n",
    "            data.to(device) for data in batch]\n",
    "            pred = skip_gram(center, context_negative, net[0], net[1])\n",
    "            l = (loss(pred.reshape(label.shape).float(), label.float(), mask)/ mask.sum(axis=1) * mask.shape[1])\n",
    "            l.sum().backward()\n",
    "            optimizer.step()\n",
    "            metric.add(l.sum(), l.numel())\n",
    "            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
    "                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[1],))\n",
    "    print(f'loss {metric[0] / metric[1]:.3f}, ' f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "98426408-e538-45f4-9a38-0c57efdfa4e2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.410, 34741.8 tokens/sec on cuda:0\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"255.825pt\" height=\"183.35625pt\" viewBox=\"0 0 255.825 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2025-04-10T16:29:25.531868</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.8.3, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 183.35625 \n",
       "L 255.825 183.35625 \n",
       "L 255.825 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 50.14375 145.8 \n",
       "L 245.44375 145.8 \n",
       "L 245.44375 7.2 \n",
       "L 50.14375 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path d=\"M 50.14375 145.8 \n",
       "L 50.14375 7.2 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path id=\"m393b4ba3f7\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m393b4ba3f7\" x=\"50.14375\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(46.9625 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path d=\"M 98.96875 145.8 \n",
       "L 98.96875 7.2 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m393b4ba3f7\" x=\"98.96875\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 2 -->\n",
       "      <g transform=\"translate(95.7875 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path d=\"M 147.79375 145.8 \n",
       "L 147.79375 7.2 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m393b4ba3f7\" x=\"147.79375\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 3 -->\n",
       "      <g transform=\"translate(144.6125 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
       "Q 3050 2419 3304 2112 \n",
       "Q 3559 1806 3559 1356 \n",
       "Q 3559 666 3084 287 \n",
       "Q 2609 -91 1734 -91 \n",
       "Q 1441 -91 1130 -33 \n",
       "Q 819 25 488 141 \n",
       "L 488 750 \n",
       "Q 750 597 1062 519 \n",
       "Q 1375 441 1716 441 \n",
       "Q 2309 441 2620 675 \n",
       "Q 2931 909 2931 1356 \n",
       "Q 2931 1769 2642 2001 \n",
       "Q 2353 2234 1838 2234 \n",
       "L 1294 2234 \n",
       "L 1294 2753 \n",
       "L 1863 2753 \n",
       "Q 2328 2753 2575 2939 \n",
       "Q 2822 3125 2822 3475 \n",
       "Q 2822 3834 2567 4026 \n",
       "Q 2313 4219 1838 4219 \n",
       "Q 1578 4219 1281 4162 \n",
       "Q 984 4106 628 3988 \n",
       "L 628 4550 \n",
       "Q 988 4650 1302 4700 \n",
       "Q 1616 4750 1894 4750 \n",
       "Q 2613 4750 3031 4423 \n",
       "Q 3450 4097 3450 3541 \n",
       "Q 3450 3153 3228 2886 \n",
       "Q 3006 2619 2597 2516 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-33\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path d=\"M 196.61875 145.8 \n",
       "L 196.61875 7.2 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m393b4ba3f7\" x=\"196.61875\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 4 -->\n",
       "      <g transform=\"translate(193.4375 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-34\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path d=\"M 245.44375 145.8 \n",
       "L 245.44375 7.2 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m393b4ba3f7\" x=\"245.44375\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(242.2625 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <g transform=\"translate(132.565625 174.076563) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
       "L 1159 -1331 \n",
       "L 581 -1331 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2969 \n",
       "Q 1341 3281 1617 3432 \n",
       "Q 1894 3584 2278 3584 \n",
       "Q 2916 3584 3314 3078 \n",
       "Q 3713 2572 3713 1747 \n",
       "Q 3713 922 3314 415 \n",
       "Q 2916 -91 2278 -91 \n",
       "Q 1894 -91 1617 61 \n",
       "Q 1341 213 1159 525 \n",
       "z\n",
       "M 3116 1747 \n",
       "Q 3116 2381 2855 2742 \n",
       "Q 2594 3103 2138 3103 \n",
       "Q 1681 3103 1420 2742 \n",
       "Q 1159 2381 1159 1747 \n",
       "Q 1159 1113 1420 752 \n",
       "Q 1681 391 2138 391 \n",
       "Q 2594 391 2855 752 \n",
       "Q 3116 1113 3116 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
       "Q 1497 3097 1228 2736 \n",
       "Q 959 2375 959 1747 \n",
       "Q 959 1119 1226 758 \n",
       "Q 1494 397 1959 397 \n",
       "Q 2419 397 2687 759 \n",
       "Q 2956 1122 2956 1747 \n",
       "Q 2956 2369 2687 2733 \n",
       "Q 2419 3097 1959 3097 \n",
       "z\n",
       "M 1959 3584 \n",
       "Q 2709 3584 3137 3096 \n",
       "Q 3566 2609 3566 1747 \n",
       "Q 3566 888 3137 398 \n",
       "Q 2709 -91 1959 -91 \n",
       "Q 1206 -91 779 398 \n",
       "Q 353 888 353 1747 \n",
       "Q 353 2609 779 3096 \n",
       "Q 1206 3584 1959 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
       "L 3122 2828 \n",
       "Q 2878 2963 2633 3030 \n",
       "Q 2388 3097 2138 3097 \n",
       "Q 1578 3097 1268 2742 \n",
       "Q 959 2388 959 1747 \n",
       "Q 959 1106 1268 751 \n",
       "Q 1578 397 2138 397 \n",
       "Q 2388 397 2633 464 \n",
       "Q 2878 531 3122 666 \n",
       "L 3122 134 \n",
       "Q 2881 22 2623 -34 \n",
       "Q 2366 -91 2075 -91 \n",
       "Q 1284 -91 818 406 \n",
       "Q 353 903 353 1747 \n",
       "Q 353 2603 823 3093 \n",
       "Q 1294 3584 2113 3584 \n",
       "Q 2378 3584 2631 3529 \n",
       "Q 2884 3475 3122 3366 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 4863 \n",
       "L 1159 4863 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path d=\"M 50.14375 114.724213 \n",
       "L 245.44375 114.724213 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <defs>\n",
       "       <path id=\"m7c8b783380\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m7c8b783380\" x=\"50.14375\" y=\"114.724213\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.45 -->\n",
       "      <g transform=\"translate(20.878125 118.523432) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path d=\"M 50.14375 83.56967 \n",
       "L 245.44375 83.56967 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m7c8b783380\" x=\"50.14375\" y=\"83.56967\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.50 -->\n",
       "      <g transform=\"translate(20.878125 87.368889) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path d=\"M 50.14375 52.415128 \n",
       "L 245.44375 52.415128 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m7c8b783380\" x=\"50.14375\" y=\"52.415128\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.55 -->\n",
       "      <g transform=\"translate(20.878125 56.214347) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_17\">\n",
       "      <path d=\"M 50.14375 21.260586 \n",
       "L 245.44375 21.260586 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_18\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m7c8b783380\" x=\"50.14375\" y=\"21.260586\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 0.60 -->\n",
       "      <g transform=\"translate(20.878125 25.059805) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
       "Q 1688 2584 1439 2293 \n",
       "Q 1191 2003 1191 1497 \n",
       "Q 1191 994 1439 701 \n",
       "Q 1688 409 2113 409 \n",
       "Q 2538 409 2786 701 \n",
       "Q 3034 994 3034 1497 \n",
       "Q 3034 2003 2786 2293 \n",
       "Q 2538 2584 2113 2584 \n",
       "z\n",
       "M 3366 4563 \n",
       "L 3366 3988 \n",
       "Q 3128 4100 2886 4159 \n",
       "Q 2644 4219 2406 4219 \n",
       "Q 1781 4219 1451 3797 \n",
       "Q 1122 3375 1075 2522 \n",
       "Q 1259 2794 1537 2939 \n",
       "Q 1816 3084 2150 3084 \n",
       "Q 2853 3084 3261 2657 \n",
       "Q 3669 2231 3669 1497 \n",
       "Q 3669 778 3244 343 \n",
       "Q 2819 -91 2113 -91 \n",
       "Q 1303 -91 875 529 \n",
       "Q 447 1150 447 2328 \n",
       "Q 447 3434 972 4092 \n",
       "Q 1497 4750 2381 4750 \n",
       "Q 2619 4750 2861 4703 \n",
       "Q 3103 4656 3366 4563 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-36\" x=\"95.410156\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_11\">\n",
       "     <!-- loss -->\n",
       "     <g transform=\"translate(14.798437 86.157813) rotate(-90) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_19\">\n",
       "    <path d=\"M 11.069197 13.5 \n",
       "L 20.819644 60.606362 \n",
       "L 30.570091 79.375847 \n",
       "L 40.320538 89.371299 \n",
       "L 50.070985 95.766561 \n",
       "L 50.14375 95.780872 \n",
       "L 59.894197 100.575799 \n",
       "L 69.644644 104.296218 \n",
       "L 79.395091 107.401767 \n",
       "L 89.145538 110.056764 \n",
       "L 98.895985 112.415367 \n",
       "L 98.96875 112.422092 \n",
       "L 108.719197 114.947604 \n",
       "L 118.469644 117.189794 \n",
       "L 128.220091 119.238253 \n",
       "L 137.970538 121.097596 \n",
       "L 147.720985 122.807598 \n",
       "L 147.79375 122.812944 \n",
       "L 157.544197 124.944671 \n",
       "L 167.294644 126.854697 \n",
       "L 177.045091 128.599618 \n",
       "L 186.795538 130.194698 \n",
       "L 196.545985 131.65231 \n",
       "L 196.61875 131.657135 \n",
       "L 206.369197 133.56534 \n",
       "L 216.119644 135.256835 \n",
       "L 225.870091 136.800366 \n",
       "L 235.620538 138.211284 \n",
       "L 245.370985 139.496168 \n",
       "L 245.44375 139.5 \n",
       "\" clip-path=\"url(#p5e94c1c40d)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 50.14375 145.8 \n",
       "L 50.14375 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 245.44375 145.8 \n",
       "L 245.44375 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 50.14375 145.8 \n",
       "L 245.44375 145.8 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 50.14375 7.2 \n",
       "L 245.44375 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p5e94c1c40d\">\n",
       "   <rect x=\"50.14375\" y=\"7.2\" width=\"195.3\" height=\"138.6\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 350x250 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Learning rate and number of epochs handed to train() on the next line\n",
    "lr, num_epochs = 0.002, 5\n",
    "train(net, data_iter, lr, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "bb3a3429-6175-48ec-a819-dd0fa684567c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6719, 100])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the weight matrix held by the first layer of net\n",
    "net[0].weight.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "6968229b-468a-4b22-8252-c45a5d0d1e8c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of a single row (index 1) of net[0]'s weight matrix\n",
    "net[0].weight.data[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a7e3e225-2571-4c22-9eee-3623c60e9270",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_similar_tokens(query_token, k, embed):\n",
    "    \"\"\"Print the k tokens whose embeddings are most cosine-similar to query_token.\n",
    "\n",
    "    Args:\n",
    "        query_token: token looked up in the notebook-level `vocab`.\n",
    "        k: number of neighbours to print.\n",
    "        embed: layer whose `weight.data` rows are indexed by vocabulary index.\n",
    "    \"\"\"\n",
    "    W = embed.weight.data\n",
    "    query_idx = vocab[query_token]\n",
    "    x = W[query_idx]\n",
    "    # Cosine similarity; 1e-9 is added under the square root for numerical stability\n",
    "    cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)\n",
    "    # Ask for one extra hit so the query can be removed; cap at the number of rows\n",
    "    # because torch.topk raises when k exceeds the dimension size\n",
    "    topk = torch.topk(cos, k=min(k + 1, cos.shape[0]))[1].cpu().numpy().astype('int32')\n",
    "    # Remove the query token by index: it is not guaranteed to be ranked first\n",
    "    neighbours = [i for i in topk if i != query_idx][:k]\n",
    "    for i in neighbours:\n",
    "        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "c8f27f9c-460f-43bd-b8b0-5bb7903e8c6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.602: optical\n",
      "cosine sim=0.544: computer\n",
      "cosine sim=0.535: portable\n"
     ]
    }
   ],
   "source": [
    "# Print the 3 nearest neighbours of 'apple' under the embeddings in net[0]\n",
    "get_similar_tokens('apple', 3, net[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4469090a-7000-4bf6-9841-532c670dc531",
   "metadata": {},
   "source": [
    "fastText可以被认为是子词级跳元模型，而非学习词级向量表示，其中每个中心词由其子词级向量之和表示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8de59eaa-7495-4632-96d0-e8c1514e1b6f",
   "metadata": {},
   "source": [
    "字节对编码执行训练数据集的统计分析，以发现单词内的公共符号，诸如任意长度的连续字符。从长度为1的符号开始，字节对编码迭代地合并最频繁的连续符号对以产生新的更长的符号。请注意，为提高效率，不考虑跨越单词边界的对。最后，我们可以使用像子词这样的符号来切分单词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8d2efc89-4150-48a9-9baf-93a770ebfd45",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "# Initial symbols: the letters a-z, then the '_' symbol and the '[UNK]' placeholder\n",
    "symbols = [chr(code) for code in range(ord('a'), ord('z') + 1)] + ['_', '[UNK]']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "7e3e9176-f533-47b0-998f-0a609d4bd586",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}\n",
    "# Re-key every word as its characters separated by single spaces\n",
    "token_freqs = {}\n",
    "for token, freq in raw_token_freqs.items():\n",
    "    token_freqs[' '.join(token)] = freq\n",
    "token_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "04666059-eda9-432b-83ea-6e6549be7f8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_freq_pair(token_freqs):\n",
    "    \"\"\"Return the most frequent pair of adjacent symbols in `token_freqs`.\n",
    "\n",
    "    Keys of `token_freqs` are space-separated symbol sequences and values are\n",
    "    their frequencies. Returns a 2-tuple of symbols, or '' when no key has two\n",
    "    or more symbols. Ties go to the pair that was seen first.\n",
    "    \"\"\"\n",
    "    pair_counts = collections.defaultdict(int)\n",
    "    for token, freq in token_freqs.items():\n",
    "        pieces = token.split()\n",
    "        # Keys of `pair_counts` are tuples of two consecutive symbols\n",
    "        for left, right in zip(pieces, pieces[1:]):\n",
    "            pair_counts[left, right] += freq\n",
    "    if not pair_counts:\n",
    "        return ''\n",
    "    # Key of `pair_counts` with the largest accumulated frequency\n",
    "    return max(pair_counts, key=pair_counts.get)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "494a34c4-92c3-4854-a1cb-2a5fae48b23e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_symbols(max_freq_pair, token_freqs, symbols):\n",
    "    \"\"\"Merge every occurrence of `max_freq_pair` in the keys of `token_freqs`.\n",
    "\n",
    "    Args:\n",
    "        max_freq_pair: tuple of two adjacent symbols to merge; a falsy value\n",
    "            (such as '') means there is nothing to merge.\n",
    "        token_freqs: dict mapping space-separated symbol sequences to frequencies.\n",
    "        symbols: list of known symbols; the merged symbol is appended in place.\n",
    "\n",
    "    Returns:\n",
    "        A new dict with the pair merged in every key, or `token_freqs` itself\n",
    "        when `max_freq_pair` is falsy.\n",
    "    \"\"\"\n",
    "    import re  # local import: only this function needs it\n",
    "\n",
    "    if not max_freq_pair:\n",
    "        return token_freqs\n",
    "    merged = ''.join(max_freq_pair)\n",
    "    symbols.append(merged)\n",
    "    # Match the pair only as two whole symbols. A plain str.replace also hits the\n",
    "    # tail of a longer symbol, e.g. merging ('a', 'b') would turn 'xa b' into 'xab'.\n",
    "    pair_pattern = re.compile(r'(?<!\\S)' + re.escape(' '.join(max_freq_pair)) + r'(?!\\S)')\n",
    "    new_token_freqs = dict()\n",
    "    for token, freq in token_freqs.items():\n",
    "        # A callable replacement keeps any backslash in `merged` literal\n",
    "        new_token_freqs[pair_pattern.sub(lambda _: merged, token)] = freq\n",
    "    return new_token_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "d9879351-5143-462a-87c9-f9482b68d5d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "合并# 1: ('t', 'a')\n",
      "合并# 2: ('ta', 'l')\n",
      "合并# 3: ('tal', 'l')\n",
      "合并# 4: ('f', 'a')\n",
      "合并# 5: ('fa', 's')\n",
      "合并# 6: ('fas', 't')\n",
      "合并# 7: ('e', 'r')\n",
      "合并# 8: ('er', '_')\n",
      "合并# 9: ('tall', '_')\n",
      "合并# 10: ('fast', '_')\n"
     ]
    }
   ],
   "source": [
    "# Run num_merges rounds: pick the most frequent adjacent pair, then merge it.\n",
    "# token_freqs is rebound and symbols is extended on every round, so re-running\n",
    "# this cell continues from the current state instead of starting over.\n",
    "num_merges = 10\n",
    "for i in range(num_merges):\n",
    "    max_freq_pair = get_max_freq_pair(token_freqs)\n",
    "    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)\n",
    "    print(f'合并# {i+1}:',max_freq_pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "19fd4cb3-2dae-4f35-8f0e-ac81969b72be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'fast_': 4, 'fast er_': 3, 'tall_': 5, 'tall er_': 4}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token frequencies, keyed by the symbol sequences as they stand after merging\n",
    "token_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "5725173a-e2bd-4f6c-997f-d11c0e0d3fcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall', 'fa', 'fas', 'fast', 'er', 'er_', 'tall_', 'fast_']\n"
     ]
    }
   ],
   "source": [
    "# Symbol list, including the merged symbols appended by merge_symbols\n",
    "print(symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1c93a586-1a0f-4d39-8be2-9ef5db6dc9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def segment_BPE(tokens, symbols):\n",
    "    \"\"\"Greedily split each token into the longest subwords found in `symbols`.\n",
    "\n",
    "    Args:\n",
    "        tokens: iterable of strings to segment.\n",
    "        symbols: collection of known symbol strings.\n",
    "\n",
    "    Returns:\n",
    "        A list with one string per token: the matched subwords joined by single\n",
    "        spaces. When no symbol matches at some position, a single '[UNK]' is\n",
    "        appended and the rest of that token is not segmented further.\n",
    "    \"\"\"\n",
    "    known = set(symbols)  # constant-time lookups instead of scanning the list per probe\n",
    "    outputs = []\n",
    "    for token in tokens:\n",
    "        start, end = 0, len(token)\n",
    "        cur_output = []\n",
    "        # Longest match first: shrink `end` until token[start:end] is a known symbol\n",
    "        while start < end:\n",
    "            piece = token[start:end]\n",
    "            if piece in known:\n",
    "                cur_output.append(piece)\n",
    "                start, end = end, len(token)\n",
    "            else:\n",
    "                end -= 1\n",
    "        # The loop stopped before the end of the token: nothing matched at `start`\n",
    "        if start < len(token):\n",
    "            cur_output.append('[UNK]')\n",
    "        outputs.append(' '.join(cur_output))\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "7fe5e3af-4549-4c76-a67c-5f05ac0657f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['tall e s t _', 'fa t t er_']\n"
     ]
    }
   ],
   "source": [
    "# Segment two example tokens with the current symbol list\n",
    "tokens = ['tallest_', 'fatter_']\n",
    "print(segment_BPE(tokens, symbols))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "7c37c82d-d1ee-4446-a897-da6410871d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "import d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cbe2c98-9ca3-412c-8eaa-fc59ab7bfdee",
   "metadata": {},
   "source": [
    "预训练 GloVe 嵌入就是在大规模文本语料库上预先训练好的 GloVe 模型所得到的词向量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "009b5fb4-27bb-4bf8-8b7b-9bc5338aa677",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
